{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "class MaxEnt:\n",
    "    \"\"\"Maximum-entropy classifier over binary value-presence features.\n",
    "\n",
    "    Every distinct value in the training matrix becomes one feature, and a\n",
    "    sample activates the features of the values it contains. Weights are\n",
    "    fitted with the scaling update delta = log(E_hat[f] / E_model[f]) / M.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_iter=100):\n",
    "        \"\"\"Create an unfitted model.\n",
    "\n",
    "        Args:\n",
    "            max_iter: Number of weight updates performed by fit.\n",
    "        \"\"\"\n",
    "        # Training inputs (set by fit)\n",
    "        self.X_ = None\n",
    "        # Distinct training labels; position k is the label of weight row k\n",
    "        self.y_ = None\n",
    "        # Number of label classes\n",
    "        self.m = None\n",
    "        # Number of weight columns: known feature values plus column 0,\n",
    "        # which is reserved for values never seen during training\n",
    "        self.n = None\n",
    "        # Number of training samples\n",
    "        self.N = None\n",
    "        # Constant that divides every update step\n",
    "        self.M = None\n",
    "        # Weight matrix of shape (m, n)\n",
    "        self.w = None\n",
    "        # Label -> weight-row index\n",
    "        self.labels = defaultdict(int)\n",
    "        # Feature value -> weight-column index\n",
    "        self.features = defaultdict(int)\n",
    "        # Maximum number of iterations\n",
    "        self.max_iter = max_iter\n",
    "\n",
    "    def _EP_hat_f(self, x, y):\n",
    "        \"\"\"Store the empirical feature expectations in self.EP_hat_f.\n",
    "\n",
    "        Also fills self.Pxy (label/feature co-occurrence counts) and self.Px\n",
    "        (per-feature sample counts). A value repeated inside one sample is\n",
    "        counted once.\n",
    "        \"\"\"\n",
    "        self.Pxy = np.zeros((self.m, self.n))\n",
    "        self.Px = np.zeros(self.n)\n",
    "        for sample, label in zip(x, y):\n",
    "            row = self.labels[label]\n",
    "            for value in set(sample):\n",
    "                col = self.features[value]\n",
    "                self.Pxy[row, col] += 1\n",
    "                self.Px[col] += 1\n",
    "        self.EP_hat_f = self.Pxy / self.N\n",
    "\n",
    "    def _EP_f(self):\n",
    "        \"\"\"Store the model feature expectations in self.EP_f.\n",
    "\n",
    "        Entry (k, j) is the sum of P(label k | sample) over the training\n",
    "        samples that contain feature j, divided by the number of samples.\n",
    "        \"\"\"\n",
    "        expected = np.zeros((self.m, self.n))\n",
    "        for sample in self.X_:\n",
    "            pw = self._pw(sample)\n",
    "            # A sample only contributes to the features it actually contains\n",
    "            for value in set(sample):\n",
    "                expected[:, self.features[value]] += pw\n",
    "        self.EP_f = expected / self.N\n",
    "\n",
    "    def _pw(self, x):\n",
    "        \"\"\"Return the model distribution P(label | x) as a length-m array.\n",
    "\n",
    "        Values never seen during training map to column 0, whose weights stay\n",
    "        at zero, so they do not influence the result.\n",
    "        \"\"\"\n",
    "        mask = np.zeros(self.n)\n",
    "        for value in x:\n",
    "            # .get avoids inserting unseen values into the defaultdict\n",
    "            mask[self.features.get(value, 0)] = 1\n",
    "        scores = np.sum(self.w * mask, axis=1)\n",
    "        # Shifting by the maximum leaves the softmax unchanged but avoids overflow\n",
    "        pw = np.exp(scores - np.max(scores))\n",
    "        return pw / np.sum(pw)\n",
    "\n",
    "    def fit(self, x, y):\n",
    "        \"\"\"Fit the weights on training data and return self.\n",
    "\n",
    "        Args:\n",
    "            x: 2-D array-like of samples; every distinct value is one feature.\n",
    "            y: Sequence of labels, one per row of x.\n",
    "\n",
    "        Returns:\n",
    "            The fitted model.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If x is empty, is not 2-D, or its length differs from y.\n",
    "        \"\"\"\n",
    "        self.X_ = np.asarray(x)\n",
    "        if self.X_.ndim != 2 or len(self.X_) == 0:\n",
    "            raise ValueError('x must be a non-empty 2-D array')\n",
    "        if len(self.X_) != len(y):\n",
    "            raise ValueError('x and y must contain the same number of samples')\n",
    "        # Distinct labels; position k is the label of weight row k\n",
    "        self.y_ = list(set(y))\n",
    "        # Distinct training values, numbered from 1 (column 0 is for unseen values)\n",
    "        values = set(self.X_.flatten())\n",
    "        self.features = defaultdict(int, zip(values, range(1, len(values) + 1)))\n",
    "        self.labels = dict(zip(self.y_, range(len(self.y_))))\n",
    "        self.n = len(self.features) + 1\n",
    "        self.m = len(self.labels)\n",
    "        self.N = len(self.X_)\n",
    "        self._EP_hat_f(self.X_, y)\n",
    "        self.w = np.zeros((self.m, self.n))\n",
    "        # Constant that divides every update step (does not change per iteration)\n",
    "        self.M = 100\n",
    "        for _ in range(self.max_iter):\n",
    "            self._EP_f()\n",
    "            # Only entries observed in training with positive model mass are\n",
    "            # updated; elsewhere the log-ratio is undefined and the step is zero\n",
    "            active = (self.EP_hat_f > 0) & (self.EP_f > 0)\n",
    "            sigma = np.zeros_like(self.w)\n",
    "            sigma[active] = np.log(self.EP_hat_f[active] / self.EP_f[active]) / self.M\n",
    "            self.w = self.w + sigma\n",
    "        print('training done.')\n",
    "        return self\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return the most probable training label for every sample in x.\n",
    "\n",
    "        Args:\n",
    "            x: Iterable of samples, each an iterable of feature values.\n",
    "\n",
    "        Returns:\n",
    "            numpy array holding one predicted label per sample.\n",
    "        \"\"\"\n",
    "        return np.array([self.y_[int(np.argmax(self._pw(sample)))] for sample in x])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(105, 4) (105,)\n"
     ]
    }
   ],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "# Load the iris measurements and their class labels\n",
    "raw_data = load_iris()\n",
    "X = raw_data.data\n",
    "labels = raw_data.target\n",
    "# Hold out 30% of the samples for testing; the fixed seed keeps the split repeatable\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    X,\n",
    "    labels,\n",
    "    test_size=0.3,\n",
    "    random_state=43,\n",
    ")\n",
    "print(X_train.shape, y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 2, 2, 2, 2])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the last five entries of the iris target array `labels`\n",
    "labels[-5:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Installation\\anaconda\\install\\lib\\site-packages\\ipykernel_launcher.py:90: RuntimeWarning: invalid value encountered in true_divide\n",
      "D:\\Installation\\anaconda\\install\\lib\\site-packages\\ipykernel_launcher.py:93: RuntimeWarning: divide by zero encountered in log\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training done.\n",
      "[0.87116843 0.04683368 0.08199789] 0 {0: 0, 1: 1, 2: 2}\n",
      "[0.00261138 0.49573305 0.50165557] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.12626693 0.017157   0.85657607] 2 {0: 0, 1: 1, 2: 2}\n",
      "[1.55221378e-04 4.45985560e-05 9.99800180e-01] 2 {0: 0, 1: 1, 2: 2}\n",
      "[7.29970746e-03 9.92687370e-01 1.29226740e-05] 1 {0: 0, 1: 1, 2: 2}\n",
      "[0.01343943 0.01247887 0.9740817 ] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.85166079 0.05241898 0.09592023] 0 {0: 0, 1: 1, 2: 2}\n",
      "[0.00371481 0.00896982 0.98731537] 2 {0: 0, 1: 1, 2: 2}\n",
      "[2.69340079e-04 9.78392776e-01 2.13378835e-02] 1 {0: 0, 1: 1, 2: 2}\n",
      "[0.01224702 0.02294254 0.96481044] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.00323508 0.98724246 0.00952246] 1 {0: 0, 1: 1, 2: 2}\n",
      "[0.00196548 0.01681989 0.98121463] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.00480966 0.00345107 0.99173927] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.00221101 0.01888735 0.97890163] 2 {0: 0, 1: 1, 2: 2}\n",
      "[9.87528545e-01 3.25313387e-04 1.21461416e-02] 0 {0: 0, 1: 1, 2: 2}\n",
      "[3.84153917e-05 5.25603786e-01 4.74357798e-01] 1 {0: 0, 1: 1, 2: 2}\n",
      "[0.91969448 0.00730851 0.07299701] 0 {0: 0, 1: 1, 2: 2}\n",
      "[3.48493252e-03 9.96377722e-01 1.37345863e-04] 1 {0: 0, 1: 1, 2: 2}\n",
      "[0.00597935 0.02540794 0.96861271] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.96593729 0.01606867 0.01799404] 0 {0: 0, 1: 1, 2: 2}\n",
      "[7.07324443e-01 2.92672257e-01 3.29961259e-06] 0 {0: 0, 1: 1, 2: 2}\n",
      "[0.96122092 0.03604362 0.00273547] 0 {0: 0, 1: 1, 2: 2}\n",
      "[9.92671813e-01 7.31265179e-03 1.55352641e-05] 0 {0: 0, 1: 1, 2: 2}\n",
      "[9.99997290e-01 2.58555077e-06 1.24081335e-07] 0 {0: 0, 1: 1, 2: 2}\n",
      "[1.77991802e-05 4.62006560e-04 9.99520194e-01] 2 {0: 0, 1: 1, 2: 2}\n",
      "[9.99995176e-01 3.85240188e-06 9.72067357e-07] 0 {0: 0, 1: 1, 2: 2}\n",
      "[0.15306343 0.21405142 0.63288515] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.25817329 0.28818997 0.45363674] 2 {0: 0, 1: 1, 2: 2}\n",
      "[2.43530473e-04 4.07929999e-01 5.91826471e-01] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.71160155 0.27290911 0.01548934] 0 {0: 0, 1: 1, 2: 2}\n",
      "[2.94976826e-06 2.51510534e-02 9.74845997e-01] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.97629163 0.00331591 0.02039245] 0 {0: 0, 1: 1, 2: 2}\n",
      "[0.04513811 0.01484173 0.94002015] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.61382753 0.38321073 0.00296174] 0 {0: 0, 1: 1, 2: 2}\n",
      "[9.65538451e-01 3.86322918e-06 3.44576854e-02] 0 {0: 0, 1: 1, 2: 2}\n",
      "[0.00924088 0.01731108 0.97344804] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.02511142 0.93818613 0.03670245] 1 {0: 0, 1: 1, 2: 2}\n",
      "[9.99127831e-01 3.29723254e-04 5.42445518e-04] 0 {0: 0, 1: 1, 2: 2}\n",
      "[0.05081665 0.0038204  0.94536295] 2 {0: 0, 1: 1, 2: 2}\n",
      "[9.99985376e-01 6.85280694e-06 7.77081022e-06] 0 {0: 0, 1: 1, 2: 2}\n",
      "[9.99791732e-01 2.06536005e-04 1.73191035e-06] 0 {0: 0, 1: 1, 2: 2}\n",
      "[2.72323181e-04 2.99692548e-03 9.96730751e-01] 2 {0: 0, 1: 1, 2: 2}\n",
      "[0.02005139 0.97151852 0.00843009] 1 {0: 0, 1: 1, 2: 2}\n",
      "[0.95642409 0.02485912 0.01871679] 0 {0: 0, 1: 1, 2: 2}\n",
      "[0.00297317 0.01261126 0.98441558] 2 {0: 0, 1: 1, 2: 2}\n",
      "0.37777777777777777\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "# Train on the training split; fit returns the fitted model itself\n",
    "maxent = MaxEnt().fit(X_train, y_train)\n",
    "# Predict the held-out samples and report the fraction classified correctly\n",
    "y_pred = maxent.predict(X_test)\n",
    "test_accuracy = accuracy_score(y_test, y_pred)\n",
    "print(test_accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
